package it.uniroma2.dtk.main;

import it.uniroma2.dtk.dt.GenericDT;
import it.uniroma2.dtk.op.convolution.ShuffledCircularConvolution;
import it.uniroma2.svd.writer.DenseBinaryMatrix;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.tree.Tree;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Enumeration;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.converters.ConverterUtils;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.converters.Saver;
import weka.core.matrix.Matrix;

/**
 * Command-line tool that converts trees written in Penn Treebank notation into
 * distributed tree vectors computed by {@link GenericDT}.
 * <p>
 * Depending on the options, the vectors are written as a dense string matrix
 * ({@code -of dsm}, the default), a dense binary matrix ({@code -of dbm}), or through a
 * weka {@link Saver} ({@code -weka}). Since {@code -input} and {@code -output} are
 * required, running without arguments prints the option summary.
 */
@SuppressWarnings("static-access")
public class DTBuilder {

	/** Command-line options understood by {@link #DTBuilder(String[])}. */
	protected static Options options = new Options();
	
	/** Non-weka output formats: dense string matrix and dense binary matrix. */
	protected enum OutputTypes {dsm, dbm};
	
	/** Selected output format; static, but overwritten by the constructor when {@code -of} is given. */
	protected static OutputTypes output_type = OutputTypes.dsm;
	/** Fully qualified class name of the weka {@link Saver} used in {@code -weka} mode. */
	protected String WekaConverter = "weka.core.converters.ArffSaver";
	/** Set by {@code -weka}: read a weka data source and write through a weka Saver. */
	protected boolean useWeka = false;
	/** Would be set by a "wekadense" option, which is currently not registered (see static block). */
	private boolean useWekaDense = false;

	/** Set by {@code -verbose}: print the accumulated distributed-tree time at the end. */
	private boolean verbose = false;
	/** Milliseconds accumulated inside {@code dt.dt(...)} calls (tree parsing excluded). */
	private long elapsed = 0;
	
	static {
		options.addOption("not_lexicalized", false, "does not consider leaf nodes");
		options.addOption("pos", false, "use pos augmented labels for leaf nodes in (lexicalized) syntactic trees");
		options.addOption(OptionBuilder.withArgName("seed")
				.hasArg()
				.withDescription("use given random seed (default = 0)")
				.create("randomSeed"));
		options.addOption(OptionBuilder.withArgName("size")
				.hasArg()
				.withDescription("use given vector size (default = 4096)")
				.create("vectorSize"));
		options.addOption(OptionBuilder.withArgName("lambda")
				.hasArg()
				.withDescription("use given lambda to weight tree fragments (default = 1)")
				.create("lambda"));
		options.addOption(OptionBuilder.withArgName("input file")
				.hasArg()
				.withDescription("load trees (in Penn Treebank notation) from the given file")
				.isRequired()
				.create("input"));
		options.addOption(OptionBuilder.withArgName("output file")
				.hasArg()
				.withDescription("print distributed trees to the given file")
				.isRequired()
				.create("output"));
		options.addOption(OptionBuilder.withArgName("operation class name")
				.hasArg()
				.withDescription("use given class as vector composition function implementation, default is shuffled circular convolution")
				.create("op"));
		options.addOption(OptionBuilder.withArgName("[dsm|dbm]")
				.hasArg()
				.withDescription("the format of the output file:  dense string matrix (dsm) and dense binary matrix (dbm), default is dsm")
				.create("of"));
		// NOTE: "wekadense" is not registered (the line below is commented out), so the parser
		// rejects -wekadense, useWekaDense is never set and process_weka_matrix_dense() is
		// unreachable from the command line.
		//options.addOption("wekadense", false, "use a weka input type and a weka output converter (default is weka.core.converters.ArffSaver but it can be specified) ");
		options.addOption("weka", false, "use a weka input type and a weka output converter (default is weka.core.converters.ArffSaver but it can be specified) ");
		options.addOption(OptionBuilder.withArgName("weka format name")
				.hasArg()
				.withDescription("select the weka output format: use the full name of the weka.core.converters.AbstractFileSaver.")
				.create("wekaconverter"));
		options.addOption("verbose", false, "print messages");
	}
	
	/**
	 * Entry point: parses {@code args}, builds the distributed trees and terminates the JVM.
	 * <p>
	 * Exits with status 0 when processing completes and 1 on any error. Option parsing
	 * failures also print the usage help.
	 *
	 * @param args command-line arguments, see {@link #options}
	 */
	public static void main(String[] args) {
		// Pessimistic default: only a fully successful run switches this to 0.
		int status = 1;
		try {
			//System.setProperty("fft2", "");
			DTBuilder dtb = new DTBuilder(args);
			dtb.process();
			status = 0;
		} catch (ParseException e) {
			System.err.println( "Parsing failed: " + e.getMessage() );
			HelpFormatter formatter = new HelpFormatter();
			formatter.printHelp( "DTBuilder", options );
		} catch (NumberFormatException e) {
			System.err.println( "Parsing numeric fields failed: " + e.getMessage() );
		} catch (Exception e) {
			e.printStackTrace();
		}
		// Previously this was always 1, which made every run look like a failure to callers.
		System.exit(status);
	}
	
	/** Distributed-tree generator configured from the command line. */
	protected GenericDT dt = null;
	/** File read by every processing mode (-input). */
	protected File inputFile = null;
	/** Destination file (-output); the weka mode appends the saver's file extension. */
	protected File outputFile = null;
	/** Not read or written anywhere in this class. */
	protected File outputFileType = null;
	
	/**
	 * Parses the command line and configures the distributed-tree generator.
	 *
	 * @param args raw command-line arguments
	 * @throws ParseException if the arguments do not match {@link #options}, e.g. a required option is missing
	 * @throws NumberFormatException if -randomSeed, -vectorSize or -lambda is not a valid number
	 * @throws Exception if the -op or -wekaconverter class cannot be loaded or instantiated, or the
	 *         converter is not a weka {@link Saver}; an invalid -of value raises IllegalArgumentException
	 */
	private DTBuilder(String[] args) throws ParseException, NumberFormatException, Exception {
		CommandLineParser parser = new GnuParser();
		CommandLine line = parser.parse(options, args);
		
		dt = new GenericDT(line.hasOption("randomSeed") ? Integer.parseInt(line.getOptionValue("randomSeed")) : 0, 
						line.hasOption("vectorSize") ? Integer.parseInt(line.getOptionValue("vectorSize")) : 4096, 
						line.hasOption("pos"), 
						!line.hasOption("not_lexicalized"), 
						line.hasOption("lambda") ? Double.parseDouble(line.getOptionValue("lambda")) : 1,
						line.hasOption("op") ? Class.forName(line.getOptionValue("op")) : ShuffledCircularConvolution.class);
		
		if (line.hasOption("of")) output_type = OutputTypes.valueOf(line.getOptionValue("of"));
		if (line.hasOption("weka")) useWeka = true;
		if (line.hasOption("verbose")) verbose = true;
		if (line.hasOption("wekadense")) useWekaDense  = true;
		if (line.hasOption("wekaconverter")) {
			WekaConverter = line.getOptionValue("wekaconverter");
			// Fail fast: instantiate once here so a bad converter is reported before any processing.
			Object o = Class.forName(WekaConverter).newInstance();
			if (!(o instanceof Saver)) {
				throw new Exception("Error \n\t" + WekaConverter + "\nThis weka converter does not exist!" );
			}
		}
		inputFile = new File(line.getOptionValue("input"));
		outputFile = new File(line.getOptionValue("output"));
	}

	
	/**
	 * Runs the processing mode selected on the command line.
	 * <p>
	 * -weka takes precedence over -of. In verbose mode, prints the total time spent
	 * computing distributed trees.
	 *
	 * @throws Exception on any I/O, parsing or conversion failure of the selected mode
	 */
	private void process() throws Exception {
		if (useWeka) process_weka();
		else if (useWekaDense) process_weka_matrix_dense();
		else if (output_type == OutputTypes.dsm ) process_dsm();
		else if (output_type == OutputTypes.dbm) process_dbm();
		else throw new Exception("Unknown output type");
		if (verbose) System.out.println("Elapsed time (for computing distributed trees) = " + elapsed +"ms"); 
	}
	
	/**
	 * Reads one Penn Treebank tree per input line and writes one text line per distributed
	 * tree, formatted by {@link ArrayMath#arrayToString}.
	 * <p>
	 * Both files are closed even if a tree fails to parse or the write fails.
	 *
	 * @throws Exception on I/O or tree parsing failures
	 */
	private void process_dsm() throws Exception {
		BufferedReader input = new BufferedReader(new FileReader(inputFile));
		try {
			FileWriter output = new FileWriter(outputFile);
			try {
				String treeString;
				while ((treeString = input.readLine()) != null) {
					Tree in = Tree.fromPennTree(treeString);
					// Only the distributed-tree computation is timed; parsing and writing are excluded.
					long begin = System.currentTimeMillis();
					double [] v  = dt.dt(in);
					elapsed += System.currentTimeMillis() - begin;

					output.write(ArrayMath.arrayToString(v) + "\n");
					
					System.out.print('.');
				}
			} finally {
				output.close();
			}
		} finally {
			input.close();
		}
		System.out.println("done");
	}

	/**
	 * Reads one Penn Treebank tree per input line and stores each distributed tree as a float
	 * row of a {@link DenseBinaryMatrix} with {@code dt.getVectorSize()} columns.
	 * <p>
	 * Every 1000 rows, prints the time spent in parsing plus DT construction and in the full
	 * per-row processing, then prints the same figures for the final partial batch. The row
	 * count is written to the matrix only after all rows succeed. The input reader and the
	 * matrix file are closed even on failure.
	 *
	 * @throws Exception on I/O, tree parsing or matrix write failures
	 */
	private void process_dbm() throws Exception {
		BufferedReader input = new BufferedReader(new FileReader(inputFile));
		try {
			DenseBinaryMatrix<Float> output = new DenseBinaryMatrix<Float>(Float.class);
			output.openFile(outputFile.getAbsolutePath(), "rw");
			try {
				output.setCol(dt.getVectorSize());
				int row = 0;
				long time_only_dt = 0;
				long time_full    = 0;
				String treeString;
				while ((treeString = input.readLine()) != null) {
					long start = System.currentTimeMillis();
					Tree tree = Tree.fromPennTree(treeString);
					long parsed = System.currentTimeMillis();
					double[] v = dt.dt(tree);
					long only_dt_end = System.currentTimeMillis();
					output.setFullRowFloat(row, ArrayMath.convertToFloatArray(v));
					long full_end = System.currentTimeMillis();
					// Batch statistics keep their original meaning (parsing included) ...
					time_only_dt += only_dt_end - start;
					time_full += full_end - start;
					// ... while the verbose total matches process_dsm (DT construction only).
					elapsed += only_dt_end - parsed;
					row++;
					if (row % 1000 == 0) {
						System.out.println((row-1) + "\tdt processing time (1000rows) =\t" + time_only_dt + "\tfull processing time (1000rows) =\t" + time_full);
						time_only_dt = 0;
						time_full = 0;
					}
				}
				System.out.println(row + "\tdt processing time ("+ row%1000 +"rows) =\t" + time_only_dt + "\tfull processing time ("+ row%1000 +"rows) =\t" + time_full);
				System.out.println("done");
				output.setRows(row);
				output.writeSize();
			} finally {
				output.closeFile();
			}
		} finally {
			input.close();
		}
	}

	/**
	 * Builds an in-memory {@code rows x vectorSize} matrix of distributed trees, one per input
	 * line, and writes it with weka's {@link Matrix#write}.
	 * <p>
	 * The input is read twice: once to count lines, once to compute the vectors. All readers
	 * and the output writer are closed even on failure.
	 *
	 * @throws Exception on I/O, tree parsing or matrix write failures
	 */
	private void process_weka_matrix_dense() throws Exception {
		// First pass: count lines so the matrix can be allocated up front.
		int rows = 0;
		BufferedReader input = new BufferedReader(new FileReader(inputFile));
		try {
			while (input.readLine() != null) rows++;
		} finally {
			input.close();
		}
		
		double [][] matrix = new double[rows][dt.getVectorSize()];
		// Second pass: one distributed tree per line.
		input = new BufferedReader(new FileReader(inputFile));
		try {
			int row = 0;
			String treeString;
			while ((treeString = input.readLine()) != null) {
				Tree tree = Tree.fromPennTree(treeString);
				long begin = System.currentTimeMillis();
				matrix[row] = dt.dt(tree);
				elapsed += System.currentTimeMillis() - begin;
				System.out.print('.');
				row++;
			}
		} finally {
			input.close();
		}
		
		System.out.println("Saving");
		Matrix m = new Matrix(matrix);
		// The writer handed to Matrix.write was previously never closed.
		FileWriter writer = new FileWriter(outputFile);
		try {
			m.write(writer);
		} finally {
			writer.close();
		}
		System.out.println("Done");
	}
	
	
	/**
	 * Converts a weka data set whose tree attributes are named {@code *:tree}.
	 * <p>
	 * Each such attribute value is parsed as a Penn Treebank tree and expanded into
	 * {@code dt.getVectorSize()} numeric attributes named {@code <attribute>_<i>}. All other
	 * attributes are appended after the tree dimensions. The result is written with the
	 * {@link Saver} named by {@link #WekaConverter} to the output path plus the saver's
	 * file extension.
	 *
	 * @throws Exception on loading, tree parsing, instantiation or saving failures
	 */
	private void process_weka() throws Exception {
		DataSource input = new ConverterUtils.DataSource(inputFile.getAbsolutePath());
		Instances input_instances = input.getStructure();
		@SuppressWarnings("unchecked")
		Enumeration<Attribute> attributes = (Enumeration<Attribute>) input_instances.enumerateAttributes();
		ArrayList<Attribute> trees = new ArrayList<Attribute>();
		ArrayList<Attribute> other_attributes = new ArrayList<Attribute>();
		
		// Partition the header: ":tree" attributes become DT dimensions, the rest is copied as-is.
		while (attributes.hasMoreElements()) {
			Attribute a = attributes.nextElement();
			if (a.name().endsWith(":tree")) trees.add(a);
			else other_attributes.add(a);
		}
		
		ArrayList<Attribute> dimensions = new ArrayList<Attribute>();
		
		for (Attribute a:trees) 
			for (int i=0;i<dt.getVectorSize();i++) dimensions.add(new Attribute( a.name() + "_" + i));

		for (Attribute a:other_attributes) {
			dimensions.add(a);
		}

		ArrayList<Instance> instances = new ArrayList<Instance>(); 

		while (true) {
			Instance instance = input.nextElement(input_instances) ;
			if (instance == null)
				break;
			// Concatenate the DT vectors of all tree attributes, in header order.
			Instance subpart_of_trees = null;
			for (Attribute a:trees) {
				Tree tree = Tree.fromPennTree(instance.stringValue(a));
				long begin = System.currentTimeMillis();
				double[] v = dt.dt(tree);
				elapsed += System.currentTimeMillis() - begin;
				Instance i = new DenseInstance(1, v);
				if (subpart_of_trees == null) subpart_of_trees = i;
				else subpart_of_trees = subpart_of_trees.mergeInstance(i);
			}
			Instance other_attribute_values = new SparseInstance(other_attributes.size());
			int num = 0;
			for (Attribute a:other_attributes) {
				// NOTE(review): other_attribute_values has no dataset attached; weka's
				// setValue(int, String) is expected to reject that for string attributes - verify.
				if (a.isString()) other_attribute_values.setValue(num, instance.stringValue(a));
				else other_attribute_values.setValue(num, instance.value(a));
				num++;
			}
			Instance final_instance = null;
			if (subpart_of_trees!=null) final_instance = subpart_of_trees.mergeInstance(other_attribute_values);
			else final_instance =  other_attribute_values;
			instances.add(final_instance);
			System.out.print('.');
		}
		// input_instances is header-only, so size the result by the converted instances instead.
		Instances instances_2 = new Instances( input_instances.relationName() + "_dt", dimensions, instances.size());
		for (Instance i:instances) instances_2.add(i);
		
		Saver output = (Saver) Class.forName(WekaConverter).newInstance();
		output.setFile(new File(outputFile.getAbsoluteFile() + output.getFileExtension()));
		ConverterUtils.DataSink.write(output,instances_2);
		
	}

}
